{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# AutoQuant\n",
    "\n",
    "This notebook contains an example of how to use AIMET AutoQuant feature.\n",
    "\n",
    "AIMET offers a suite of neural network post-training quantization (PTQ) techniques that can be applied in succession. However, finding the right sequence of techniques to apply is time-consuming and can be challenging for non-expert users. We instead recommend AutoQuant to save time and effort.\n",
    "\n",
    "AutoQuant is an API that analyzes the model and automatically applies various PTQ techniques based on best-practices heuristics. You specify a tolerable accuracy drop, and AutoQuant applies PTQ techniques cumulatively until the target accuracy is satisfied.\n",
    "\n",
    "## Overall flow\n",
    "\n",
    "This example performs the following steps:\n",
    "\n",
    "1. Define constants and helper functions\n",
    "2. Load a pretrained FP32 model\n",
    "3. Run AutoQuant\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "Note\n",
    "\n",
    "This notebook does not show state-of-the-art results. For example, it uses a relatively quantization-friendly model (Resnet18). Also, some optimization parameters like number of fine-tuning epochs are chosen to improve execution speed in the notebook.\n",
    "\n",
    "</div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Dataset\n",
    "\n",
    "This example does image classification on the ImageNet dataset. If you already have a version of the data set, use that. Otherwise download the data set, for example from https://image-net.org/challenges/LSVRC/2012/index .\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "Note\n",
    "\n",
    "The dataloader provided in this example relies on these features of the ImageNet data set:\n",
    "\n",
    "- Subfolders `train` for the training samples and `val` for the validation samples. See the [pytorch dataset description](https://pytorch.org/vision/0.8/_modules/torchvision/datasets/imagenet.html) for more details.\n",
    "- One subdirectory per class, and one file per image sample.\n",
    "\n",
    "</div>\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "Note\n",
    "\n",
    "To speed up the execution of this notebook, you can use a reduced subset of the ImageNet dataset. For example: The entire ILSVRC2012 dataset has 1000 classes, 1000 training samples per class and 50 validation samples per class. However, for the purpose of running this notebook, you can reduce the dataset to, say, two samples per class.\n",
    "\n",
    "</div>\n",
    "\n",
    "Edit the cell below to specify the directory where the downloaded ImageNet dataset is saved."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "is_executing": true
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "from torchvision import transforms, datasets\n",
    "\n",
    "DATASET_DIR = '/path/to/dataset'   # Replace this path with a real directory\n",
    "\n",
    "val_transforms = transforms.Compose([\n",
    "    transforms.CenterCrop(224),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),\n",
    "])\n",
    "\n",
    "imagenet_dataset = datasets.ImageFolder(root=os.path.join(DATASET_DIR, 'val'), transform=val_transforms)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Define Constants and Helper functions\n",
    "\n",
    "This section defines the following constants and helper functions:\n",
    "\n",
    "- **EVAL_DATASET_SIZE** A typical value is 5000. In this example, the value has been set to 500 for faster execution.\n",
    "- **CALIBRATION_DATASET_SIZE** A typical value is 2000. In this example, the value has been set to 200 for faster execution.\n",
    "- **_create_sampled_data_loader()** returns a DataLoader based on the dataset and the number of samples provided.\n",
    "- **eval_callback()** defines an evaluation function for the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import random\n",
    "from typing import Optional\n",
    "from tqdm import tqdm\n",
    "import torch\n",
    "from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, Subset\n",
    "from aimet_torch.utils import in_eval_mode, get_device\n",
    "\n",
    "EVAL_DATASET_SIZE = 500\n",
    "CALIBRATION_DATASET_SIZE = 200\n",
    "\n",
    "_datasets = {}\n",
    "\n",
    "def _create_sampled_data_loader(dataset, num_samples):\n",
    "    if num_samples not in _datasets:\n",
    "        indices = random.sample(range(len(dataset)), num_samples)\n",
    "        _datasets[num_samples] = Subset(dataset, indices)\n",
    "    return DataLoader(_datasets[num_samples], batch_size=32)\n",
    "\n",
    "\n",
    "def eval_callback(model: torch.nn.Module, num_samples: Optional[int] = None) -> float:\n",
    "    if num_samples is None:\n",
    "        num_samples = EVAL_DATASET_SIZE\n",
    "\n",
    "    data_loader = _create_sampled_data_loader(imagenet_dataset, num_samples)\n",
    "    device = get_device(model)\n",
    "    \n",
    "    correct = 0\n",
    "    with in_eval_mode(model), torch.no_grad():\n",
    "        for image, label in tqdm(data_loader):\n",
    "            image = image.to(device)\n",
    "            label = label.to(device)\n",
    "            logits = model(image)\n",
    "            top1 = logits.topk(k=1).indices\n",
    "            correct += (top1 == label.view_as(top1)).sum()\n",
    "\n",
    "    return int(correct) / num_samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load a pretrained FP32 model\n",
    "\n",
    "**Load a pretrained resnet18 model from torchvision.** \n",
    "\n",
    "You can load any pretrained PyTorch model instead."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from torchvision.models import resnet18\n",
    "\n",
    "model = resnet18(pretrained=True).eval()\n",
    "\n",
    "if torch.cuda.is_available():\n",
    "    model.to(torch.device('cuda'))\n",
    "\n",
    "accuracy = eval_callback(model)\n",
    "print(f'- FP32 accuracy: {accuracy}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Run AutoQuant\n",
    "\n",
    "**3.1 Create an AutoQuant object.**\n",
    "\n",
    "The AutoQuant feature uses an unlabeled dataset to quantize the model. The **UnlabeledDatasetWrapper** class creates an unlabeled Dataset object from a labeled Dataset. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from aimet_torch.v1.auto_quant import AutoQuant\n",
    "\n",
    "class UnlabeledDatasetWrapper(Dataset):\n",
    "    def __init__(self, dataset):\n",
    "        self._dataset = dataset\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self._dataset)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        images, _ = self._dataset[index]\n",
    "        return images\n",
    "\n",
    "\n",
    "unlabeled_imagenet_dataset = UnlabeledDatasetWrapper(imagenet_dataset)\n",
    "unlabeled_imagenet_data_loader = _create_sampled_data_loader(unlabeled_imagenet_dataset,\n",
    "                                                             CALIBRATION_DATASET_SIZE)\n",
    "\n",
    "dummy_input = torch.randn((1, 3, 224, 224)).to(get_device(model))\n",
    "\n",
    "auto_quant = AutoQuant(model,\n",
    "                       dummy_input=dummy_input,\n",
    "                       data_loader=unlabeled_imagenet_data_loader,\n",
    "                       eval_callback=eval_callback)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**3.2 Run AutoQuant inference**.\n",
    "\n",
    "AutoQuant inference uses the **eval_callback** with the generic quantized model without applying PTQ techniques. This provides a baseline evaluation score before running AutoQuant optimization."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sim, initial_accuracy = auto_quant.run_inference()\n",
    "print(f\"- Quantized Accuracy (before optimization): {initial_accuracy}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**3.3 Set AdaRound Parameters (optional)**.\n",
    "\n",
    "AutoQuant uses predefined default parameters for AdaRound.\n",
    "These values were determined empirically and work well with the common models.\n",
    "\n",
    "If necessary, you can use custom parameters for Adaround.\n",
    "This example uses very small AdaRound parameters for faster execution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from aimet_torch.v1.adaround.adaround_weight import AdaroundParameters\n",
    "\n",
    "ADAROUND_DATASET_SIZE = 200\n",
    "adaround_data_loader = _create_sampled_data_loader(unlabeled_imagenet_dataset, ADAROUND_DATASET_SIZE)\n",
    "adaround_params = AdaroundParameters(adaround_data_loader, num_batches=len(adaround_data_loader), default_num_iterations=2000)\n",
    "auto_quant.set_adaround_params(adaround_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**3.4 Run AutoQuant Optimization**.\n",
    "\n",
    "This step runs AutoQuant optimization. AutoQuant returns the following:\n",
    "- The best possible quantized model\n",
    "- The corresponding evaluation score\n",
    "- The path to the encoding file\n",
    "\n",
    "The **allowed_accuracy_drop** indicates the tolerable accuracy drop. AutoQuant applies a series of quantization features until the target accuracy (FP32 accuracy - allowed accuracy drop) is satisfied. When the target accuracy is reached, AutoQuant returns immediately without applying furhter PTQ techniques. See the [AutoQuant User Guide](https://quic.github.io/aimet-pages/releases/latest/user_guide/auto_quant.html) and [AutoQuant API documentation](https://quic.github.io/aimet-pages/releases/latest/api_docs/torch_auto_quant.html) for details."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "model, optimized_accuracy, encoding_path = auto_quant.optimize(allowed_accuracy_drop=0.01)\n",
    "print(f\"- Quantized Accuracy (after optimization):  {optimized_accuracy}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "## Next steps\n",
    "\n",
    "The next step is to export this model for installation on the target.\n",
    "\n",
    "**Export the model and encodings.**\n",
    "\n",
    "- Export the model with the updated weights but without the fake quant ops. \n",
    "- Export the encodings (scale and offset quantization parameters). AIMET QuantizationSimModel provides an export API for this purpose.\n",
    "\n",
    "The following code performs these exports."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "os.makedirs('./output/', exist_ok=True)\n",
    "dummy_input = dummy_input.cpu()\n",
    "sim.export(path='./output/', filename_prefix='resnet18_after_cle_bc', dummy_input=dummy_input)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## For more information\n",
    "\n",
    "See the [AIMET API docs](https://quic.github.io/aimet-pages/AimetDocs/api_docs/index.html) for details about the AIMET APIs and optional parameters.\n",
    "\n",
    "See the [other example notebooks](https://github.com/quic/aimet/tree/develop/Examples/torch/quantization) to learn how to use other AIMET post-training quantization techniques."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
